import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from iclr23code.utils import unfold


class Neuron(nn.Module):
    """Base class for spiking neurons with shared, class-wide bookkeeping.

    Every successfully constructed neuron registers itself in ``Neuron.instance``
    so that :meth:`reset_state` can refresh all of them at once. Two class-wide
    flags control optional statistics:

    * ``gate``       -- while true, :meth:`statistic` accumulates spike counts.
    * ``count_gate`` -- checked together with ``gate`` by subclasses' ``forward``
      to decide whether membrane-potential mean/variance are recorded.

    The flags and the registry always live on ``Neuron`` itself. The class
    methods that modify them write to ``Neuron`` even when invoked through a
    subclass (e.g. ``LIF.open_gate()``), because all readers use ``Neuron.gate``,
    ``Neuron.count_gate`` and ``Neuron.instance``.

    Subclasses must implement :meth:`forward`, :meth:`init_param` and
    :meth:`detach_param`.
    """

    # Every successfully constructed neuron, in creation order (see reset_state()).
    instance = []
    gate = False  # For compute fire rate
    count_gate = False  # For collect mean and var

    def __init__(self, t, bn, spike_func=None, slope=1.0, threshold=1.0, weak_mem=1.0,
                 reset_mechanism='zero'):
        """Validate the hyper-parameters and build the neuron.

        Args:
            t: Number of simulation time steps; must be > 0. Stored as an int32 buffer.
            bn: Normalisation ``nn.Module`` used by subclasses' ``forward``.
            spike_func: Factory called as ``spike_func(slope)``; the callable it
                returns is applied to ``mem - threshold`` to produce spikes.
            slope: Surrogate slope (> 0), stored as a float buffer.
            threshold: Firing threshold (> 0, since the membrane starts at 0).
            weak_mem: Membrane decay factor in ``[0, 1]``.
            reset_mechanism: ``'zero'`` or ``'subtract'``.

        Raises:
            ValueError: If ``t``, ``slope``, ``threshold``, ``weak_mem`` or
                ``reset_mechanism`` is out of range.
            TypeError: If ``bn`` is not an ``nn.Module`` or ``spike_func`` is
                missing or not callable.
        """
        super(Neuron, self).__init__()

        if t <= 0:
            raise ValueError("`t` must be a integer bigger than 0.")
        if slope <= 0:
            raise ValueError("`slope` must bigger than 0.")
        # The membrane potential is initialised to 0, so the threshold must lie above it.
        if threshold <= 0:
            raise ValueError("`threshold` need larger than init_mem, which is zero.")
        if weak_mem < 0 or weak_mem > 1.0:
            raise ValueError("`weak_mem` must in [0, 1].")
        if reset_mechanism not in ('subtract', 'zero'):
            raise ValueError("`reset_mechanism` must be set to either 'subtract` or 'zero`.")
        if not isinstance(bn, nn.Module):
            raise TypeError("`bn` need to be a instance of torch.nn.Module.")
        if spike_func is None:
            raise TypeError("`spike_func` need to be select in '.surrogate.py`.")
        if not callable(spike_func):
            raise TypeError("`spike_func` must callable.")

        # Notes for resuming / re-using a trained model in a different setup:
        #   t        -- buffer (saved); can be changed later with update_t().
        #   slope    -- buffer (saved); spike_grad is NOT rebuilt automatically, so
        #               call update_grad() on one neuron or Neuron.reset_state() on all.
        #   threshold, weak_mem -- buffers (saved); nothing else required.
        #   bn       -- submodule; the same bn must be used (for tdBN also the same
        #               scale and threshold).
        #   spike_func, reset_mechanism -- plain attributes; must be configured identically.
        self.register_buffer('t', torch.tensor(t, dtype=torch.int32))
        self.bn = bn
        self.spike_func = spike_func
        self.register_buffer('slope', torch.tensor(slope, dtype=torch.float))
        self.spike_grad = spike_func(slope)
        self.register_buffer("threshold", torch.tensor(threshold, dtype=torch.float))
        self.register_buffer("weak_mem", torch.tensor(weak_mem, dtype=torch.float))
        self.reset_mechanism = reset_mechanism

        # Membrane potential: the scalar 0 until the first integration step.
        self.mem = 0

        # Spike statistics accumulated by statistic() while Neuron.gate is open.
        self.fire_num = 0
        self.output_dim = 0

        # Membrane-potential statistics written by subclasses' forward() when gated.
        self.mean = 0
        self.var = 0

        # Register only once fully initialised: a construction rejected above must
        # not leave a half-built neuron behind for reset_state() to trip over.
        Neuron.instance.append(self)

    def update_t(self, new_t):
        """Replace the number of time steps, keeping the buffer's device.

        ``new_t`` may be a Python number or a one-element tensor.
        """
        value = new_t.item() if isinstance(new_t, torch.Tensor) else new_t
        self.t = torch.tensor(value, dtype=torch.int32, device=self.t.device)

    def reset_count(self):
        """Clear the recorded membrane-potential mean and variance."""
        self.mean = 0
        self.var = 0

    def reset_num(self):
        """Clear the accumulated spike count and neuron count."""
        self.fire_num = 0
        self.output_dim = 0

    def update_grad(self):
        """Rebuild ``spike_grad`` from the current ``slope`` buffer."""
        self.spike_grad = self.spike_func(self.slope)

    def fire(self):
        """Compare the membrane potential against the threshold.

        Returns:
            (spk, reset): ``spk`` is the surrogate output for ``mem - threshold``
            (keeps the autograd graph); ``reset`` is a detached copy of it.
        """
        mem_shift = self.mem - self.threshold
        spk = self.spike_grad(mem_shift)
        reset = spk.clone().detach()
        return spk, reset

    def statistic(self, reset):
        """Accumulate spike statistics while ``Neuron.gate`` is open.

        Adds the number of spikes in ``reset`` to ``fire_num``. On the first call
        after a reset of ``output_dim``, stores the number of elements per entry
        of the leading dimension (presumably the batch -- TODO confirm).
        NOTE(review): ``fire_num`` sums over the leading dimension and over every
        call, while ``output_dim`` is stored once; normalisation presumably
        happens elsewhere.
        """
        if not Neuron.gate:
            return
        self.fire_num += reset.sum().float()
        if self.output_dim == 0:
            per_entry = reset.size()[1:].numel()
            self.output_dim += torch.tensor(per_entry, dtype=torch.float, device=reset.device)

    def detach_param(self):
        """Detach the neuron state from the autograd graph (subclass hook)."""
        raise NotImplementedError("`detach_param` method is not complete.")

    def init_param(self):
        """Reset the neuron state (subclass hook)."""
        raise NotImplementedError("`init_param` method is not complete.")

    def forward(self, x):
        """Run the neuron over all time steps (subclass hook)."""
        raise NotImplementedError("`forward` method is not complete.")

    # The class-level switches below always target ``Neuron`` rather than ``cls``:
    # readers use ``Neuron.<attr>``, so assigning on a subclass would only create a
    # shadow attribute that nothing reads.

    @classmethod
    def clear_neuron(cls):
        """Forget all registered neurons and close both gates."""
        Neuron.instance = []
        Neuron.gate = False
        Neuron.count_gate = False

    @classmethod
    def open_gate(cls):
        """Enable spike-count statistics for all neurons."""
        Neuron.gate = True

    @classmethod
    def close_gate(cls):
        """Disable spike-count statistics for all neurons."""
        Neuron.gate = False

    @classmethod
    def open_count_gate(cls):
        """Enable membrane mean/variance recording (also requires the gate to be open)."""
        Neuron.count_gate = True

    @classmethod
    def close_count_gate(cls):
        """Disable membrane mean/variance recording."""
        Neuron.count_gate = False

    @classmethod
    def reset_state(cls, new_t=None):
        """Rebuild ``spike_grad`` for every registered neuron, optionally setting a new ``t``."""
        for layer in Neuron.instance:
            layer.update_grad()
            if new_t is not None:
                layer.update_t(new_t)


class LIF(Neuron):
    """Leaky integrate-and-fire neuron.

    Each time step: ``mem = mem * weak_mem + input``. A spike is produced by the
    inherited ``fire`` (surrogate of ``mem - threshold``). The membrane is then
    reset with the *detached* ``reset`` tensor returned by ``fire``, either to
    zero (``'zero'``) or by subtracting ``threshold`` (``'subtract'``).
    """

    def __init__(self, t, bn, spike_func=None, slope=1.0, threshold=1.0, weak_mem=1.0,
                 reset_mechanism='zero'):
        super(LIF, self).__init__(t, bn, spike_func, slope, threshold, weak_mem, reset_mechanism)

    def detach_param(self):
        """Detach the membrane potential from the autograd graph in place.

        No-op while ``mem`` is still the scalar 0 set by ``init_param``.
        """
        if isinstance(self.mem, torch.Tensor):
            self.mem.detach_()

    def init_param(self):
        """Reset the membrane potential to the scalar 0."""
        self.mem = 0

    def _step_inputs(self, x):
        """Yield the input current for each of the ``t`` time steps.

        If ``bn`` has a ``tdbn`` attribute it is applied to the whole input before
        it is unfolded into time steps; otherwise it is applied to each step.
        """
        if hasattr(self.bn, 'tdbn'):
            out = unfold(self.bn(x), self.t)
            for i in range(self.t):
                yield out[i]
        else:
            out = unfold(x, self.t)
            for i in range(self.t):
                yield self.bn(out[i])

    def forward(self, x):
        """
        input: [T * B, C, H, W]
        output: [T * B, C, H, W]
        """
        # Decide once per call whether to record membrane statistics, so the
        # per-step recording and the final stack/mean/var always agree.
        record = Neuron.gate and Neuron.count_gate
        spike_out = []
        mem_list = []
        for current in self._step_inputs(x):
            self.mem = self.mem * self.weak_mem + current
            spk, reset = self.fire()
            self.statistic(reset)
            if record:
                mem_list.append(self.mem.clone().detach())
            spike_out.append(spk)
            # The detached ``reset`` is used, so no gradient flows through the reset.
            if self.reset_mechanism == 'subtract':
                self.mem = self.mem - reset * self.threshold
            else:
                self.mem = self.mem * (1 - reset)
        if record:
            mem_all = torch.stack(mem_list, dim=0)
            self.mean = mem_all.mean()
            self.var = mem_all.var()
        return torch.cat(spike_out, dim=0)

class NLIF(Neuron):
    """Leaky integrate-and-fire neuron with a differentiable reset.

    Each time step: ``mem = mem * weak_mem + input``, then a spike is produced
    from ``mem - threshold``. Unlike a detached reset, the membrane is reset
    with the surrogate spike ``spk`` itself, so gradients also flow through the
    reset path. The detached ``reset`` tensor is only used for statistics.
    """

    def __init__(self, t, bn, spike_func=None, slope=1.0, threshold=1.0, weak_mem=1.0,
                 reset_mechanism='zero'):
        super(NLIF, self).__init__(t, bn, spike_func, slope, threshold, weak_mem, reset_mechanism)

    def detach_param(self):
        """Detach the membrane potential from the autograd graph in place.

        No-op while ``mem`` is still the scalar 0 set by ``init_param``.
        """
        if isinstance(self.mem, torch.Tensor):
            self.mem.detach_()

    def init_param(self):
        """Reset the membrane potential to the scalar 0."""
        self.mem = 0

    def fire(self):
        """Compare the membrane potential against the threshold.

        Returns:
            (spk, reset): the surrogate spike and a detached alias of it
            (``detach()`` without ``clone()``, so it shares storage with ``spk``).
        """
        mem_shift = self.mem - self.threshold
        spk = self.spike_grad(mem_shift)
        reset = spk.detach()
        return spk, reset

    def _step_inputs(self, x):
        """Yield the input current for each of the ``t`` time steps.

        If ``bn`` has a ``tdbn`` attribute it is applied to the whole input before
        it is unfolded into time steps; otherwise it is applied to each step.
        """
        if hasattr(self.bn, 'tdbn'):
            out = unfold(self.bn(x), self.t)
            for i in range(self.t):
                yield out[i]
        else:
            out = unfold(x, self.t)
            for i in range(self.t):
                yield self.bn(out[i])

    def forward(self, x):
        """
        input: [T * B, C, H, W]
        output: [T * B, C, H, W]
        """
        # Decide once per call whether to record membrane statistics, so the
        # per-step recording and the final stack/mean/var always agree.
        record = Neuron.gate and Neuron.count_gate
        spike_out = []
        mem_list = []
        for current in self._step_inputs(x):
            self.mem = self.mem * self.weak_mem + current
            spk, reset = self.fire()
            self.statistic(reset)
            if record:
                mem_list.append(self.mem.clone().detach())
            spike_out.append(spk)
            # Reset with the non-detached spike: gradients flow through the reset.
            if self.reset_mechanism == 'subtract':
                self.mem = self.mem - spk * self.threshold
            else:
                self.mem = self.mem * (1 - spk)
        if record:
            mem_all = torch.stack(mem_list, dim=0)
            self.mean = mem_all.mean()
            self.var = mem_all.var()
        return torch.cat(spike_out, dim=0)


